#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Apr 19 15:54:38 2022

@author: qiguangyao
"""


#%%Lib
import copy
import numpy as np
import pickle as pkl
import matplotlib.pyplot as plt
from scipy import stats
import seaborn as sns 
from scipy import asarray as ar,exp
from scipy.optimize import curve_fit
import math
import pingouin as pg
from sklearn import linear_model
from pylab import cos
import pandas as pd
import random
from statsmodels.formula.api import ols
from statsmodels.stats.anova import anova_lm
from statsmodels.sandbox.stats.multicomp import multipletests # for multiple comparisons correction
from statsmodels.stats.multicomp import pairwise_tukeyhsd
# Debug/trace output: echo the path of this script when it is executed.
print("__file Output:",__file__)
#%%functions
import scipy.stats
def mean_confidence_interval(data, confidence=0.95):
    """Return the sample mean of *data* with a two-sided Student-t interval.

    The half-width is ``sem * t.ppf((1 + confidence) / 2, n - 1)``, where
    ``sem`` is ``scipy.stats.sem`` (standard error with ddof=1) and ``n`` is
    the number of samples.

    Parameters
    ----------
    data : array_like
        One-dimensional sample values.
    confidence : float, optional
        Coverage of the interval (default 0.95).

    Returns
    -------
    tuple
        ``(mean, lower, upper)``.
    """
    values = np.array(data) * 1.0  # force a floating-point copy
    dof = len(values) - 1
    centre = np.mean(values)
    std_err = scipy.stats.sem(values)
    tail_quantile = (1 + confidence) / 2.
    half_width = std_err * scipy.stats.t.ppf(tail_quantile, dof)
    return centre, centre - half_width, centre + half_width

def adjust_spines(ax, spines):
    """Keep only the named spines of *ax*, offset 10 points outward.

    Spines whose name is not in *spines* are hidden (colour 'none').
    Ticks are placed on the left / bottom only when that spine is kept;
    otherwise every tick on the corresponding axis is removed.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to modify in place.
    spines : container of str
        Spine names to keep, e.g. ``['left', 'bottom']``.
    """
    for side, spine in ax.spines.items():
        if side not in spines:
            spine.set_color('none')  # hide this spine
            continue
        spine.set_position(('outward', 10))  # push outward by 10 points

    # Ticks follow the kept spines; an axis without its spine gets no ticks.
    for side, axis in (('left', ax.yaxis), ('bottom', ax.xaxis)):
        if side in spines:
            axis.set_ticks_position(side)
        else:
            axis.set_ticks([])
        
def gaus(x,a,x0,sigma):
    """Scaled normal density: ``a`` times the N(x0, sigma**2) PDF at ``x``.

    With ``a == 1`` the curve integrates to 1 over the real line, so ``a``
    is the area under the curve.

    Parameters
    ----------
    x : float or numpy.ndarray
        Evaluation point(s).
    a : float
        Overall scale (area under the curve).
    x0 : float
        Centre (mean).
    sigma : float
        Width (standard deviation); must be non-zero.

    Returns
    -------
    float or numpy.ndarray
        Density values with the same shape as ``x``.
    """
    # Normalisation must be 1 / (sigma * sqrt(2*pi)). Writing
    # `1/sigma*np.sqrt(2*np.pi)` parses as sqrt(2*pi)/sigma, which is 2*pi
    # times too large.
    norm = 1.0 / (sigma * np.sqrt(2 * np.pi))
    return a * norm * np.exp(-(x - x0) ** 2 / (2 * sigma ** 2))

def gaussian(X, amp, cen, wid):
    """Unnormalised Gaussian bump ``amp * exp(-(X - cen)**2 / wid)``.

    ``wid`` divides the squared distance directly, so it plays the role of
    ``2 * sigma**2`` (not ``sigma``).

    Parameters
    ----------
    X : float or numpy.ndarray
        Evaluation point(s).
    amp : float
        Peak height (value at ``X == cen``).
    cen : float
        Centre of the bump.
    wid : float
        Width term; must be non-zero.
    """
    # Use numpy's exp directly rather than the deprecated `scipy.exp` alias.
    return amp * np.exp(-(X - cen) ** 2 / wid)

def getPossionPDF(mu,x):
    """Return the Poisson probability mass ``exp(-mu) * mu**k / k!``.

    The count ``x`` is clamped to [0, 170] and rounded to the nearest
    integer ``k``; 170 is the largest n whose factorial still fits in a
    double. A constant 0.01 is added to ``mu`` before evaluation
    (presumably so that a zero rate does not give a degenerate distribution).

    Parameters
    ----------
    mu : float
        Poisson rate (expected count).
    x : float or int
        Observed count; non-integer values are rounded.

    Returns
    -------
    float
        The probability mass. Negative results (only possible for a
        negative rate) are clipped to 0.
    """
    # Clamp to [0, 170] and convert to a true int: math.factorial rejects
    # floats, and 171! overflows a double.
    k = int(round(min(max(x, 0), 170)))
    mu = mu + 0.01
    if mu > 0:
        # Evaluate in log space so that large rates do not overflow mu**k
        # (e.g. 500.01**170 raises OverflowError).
        out = math.exp(k * math.log(mu) - mu - math.lgamma(k + 1))
    else:
        # Non-positive rate: use the direct formula. Odd powers of a
        # negative rate give a negative value, which is clipped below.
        out = math.exp(-mu) * (mu ** k) / math.factorial(k)
    return max(out, 0)

#tuning curve fitting
def vonMisesFunction(x,b,a,u):
    """Rectified cosine tuning curve ``max(b + a*cos(x - u), 0)``.

    Despite the name, this is a half-wave-rectified cosine rather than the
    von Mises form ``exp(k*cos(x - u))``.

    Parameters
    ----------
    x : float or array_like
        Stimulus / hand position (angle in radians for ``cos``).
    b : float
        Additive offset.
    a : float
        Cosine amplitude.
    u : float
        Phase; the curve peaks at ``x == u`` when ``a > 0``.

    Returns
    -------
    numpy.ndarray
        Curve values with negative entries set to 0 (0-d for scalar ``x``).
    """
    # np.cos replaces the pylab namespace's cos; asarray lets list input work.
    out = np.array(b + a * np.cos(np.asarray(x) - u))
    out[out < 0] = 0  # firing rates cannot be negative
    return out

def getvonMisesParas(x,y):
    """Least-squares fit of ``vonMisesFunction`` to tuning data.

    Parameters
    ----------
    x : array_like
        Hand position.
    y : array_like
        Firing rate.

    Returns
    -------
    numpy.ndarray
        Fitted ``[b, a, u]``: offset, cosine amplitude and phase.
    """
    # Starting point for [b, a, u]; allow many function evaluations.
    starting_guess = [1, 0, 1]
    fitted_params = curve_fit(vonMisesFunction, x, y,
                              p0=starting_guess, maxfev=500000)[0]
    return fitted_params

def getExpParas(x,y):
    """Fit ``expFunction`` (``a * exp(-b * x) + c``) to the points (x, y).

    Parameters
    ----------
    x : array_like
        Independent variable (in this script: trial lag).
    y : array_like
        Dependent variable (in this script: decoding accuracy per lag).

    Returns
    -------
    numpy.ndarray
        Best-fit parameters ``[a, b, c]`` from ``scipy.optimize.curve_fit``.
    """
    init_vals = [1, 0, 1]  # initial guess for [a, b, c]
    best_vals, _ = curve_fit(expFunction, x, y, p0=init_vals, maxfev=500000)
    return best_vals

def expFunction(x, a, b, c):
    """Exponential approach to an asymptote: ``a * exp(-b * x) + c``.

    Parameters
    ----------
    x : float or numpy.ndarray
        Evaluation point(s).
    a : float
        Amplitude of the exponential term.
    b : float
        Rate constant (decay when ``b > 0``).
    c : float
        Asymptote as ``x`` grows (for ``b > 0``).
    """
    decay = np.exp(-b * x)
    return a * decay + c

#%% ------------figure 5---------------- 
# Load every pre-computed array used by the Figure 5 panels below.
# The file is opened in a `with` block so the handle is closed after loading.
# NOTE(review): pickle.load can execute arbitrary code -- only load a
# trusted file.
with open('fig5Data.pickle','rb') as fig5File:
    fig5Data = pkl.load(fig5File)
#5A
downRelaDrifFRBin = fig5Data['downRelaDrifFRBin']
upRelaDrifFRBin = fig5Data['upRelaDrifFRBin']
neuralEventTargetOnset = fig5Data['neuralEventTargetOnset']
PEventIndex = fig5Data['PEventIndex']
smallRelaDrifEventIndex = fig5Data['smallRelaDrifEventIndex']
largeRelaDrifEventIndex = fig5Data['largeRelaDrifEventIndex']
VPHandEventIndex = fig5Data['VPHandEventIndex']
VPHandFRTemp = fig5Data['VPHandFRTemp']
PFRTemp = fig5Data['PFRTemp']
neurName = fig5Data['neurName']

#5B
meanDownUpPMCLag1 = fig5Data['meanDownUpPMCLag1']
meanDownUpArea5Lag1 = fig5Data['meanDownUpArea5Lag1']

#5C

PMCRefeCorrMeanUpLowLag = fig5Data['PMCRefeCorrMeanUpLowLag']
Area5RefeCorrMeanUpLowLag = fig5Data['Area5RefeCorrMeanUpLowLag']

#5E
plotPCs =fig5Data['plotPCs']
exp_var =fig5Data['exp_var']
pom_on_pri =fig5Data['pom_on_pri']
pri_on_pcom = fig5Data['pri_on_pcom']

#5F
corrRateScoreBinSmooPMCMeanPosterior = fig5Data['corrRateScoreBinSmooPMCMeanPosterior']
corrRateScoreBinSmooPMCMeanPrior = fig5Data['corrRateScoreBinSmooPMCMeanPrior']
#%%fig5A
# Figure 5A: spike raster (top axis) and trial-averaged firing rate
# (bottom axis) for the example neuron `neurName`, aligned to target onset.
# Width (s) and height (in trial rows) of each spike tick in the raster.
widthRast = 0.005#0.000005
hightRast = 0.9#.000005

# NOTE(review): the two bare string literals below are no-op expressions
# (unused colour codes); they do not affect the figure.
'#802629'
'#ff8fa0'
# 22 bin positions from -0.8 s to 1.3 s in 0.1 s steps.
timeBins = [i/10-.8 for i in range(22)]
with plt.style.context('style_paper.mplstyle'):
    f, (ax1, ax2) = plt.subplots(ncols=1, nrows=2, sharex=True,
                                  figsize=[7.25/2.5*.83,3.54/1.5],#New3 20210706#7.25/3
                                  gridspec_kw={'height_ratios': [1, 1]})
    # Raster rows are stacked upward from y=1, one row per trial.
    plotTrial = 1
    # Randomise the trial order within each condition. random.shuffle works
    # in place, so this also reorders the objects held in fig5Data, and the
    # raster order differs from run to run (no seed is set here).
    random.shuffle(PEventIndex)
    random.shuffle(smallRelaDrifEventIndex)
    random.shuffle(largeRelaDrifEventIndex)
    random.shuffle(VPHandEventIndex)
    # Condition groups are drawn bottom-to-top in this order:
    # P (gray), small relative drift ('#1f77b4'), large relative drift
    # ('#d62728'), VP (black). Each event time in row i of
    # neuralEventTargetOnset becomes a thin rectangle centred on (time, row).
    for i in PEventIndex:
        for k in range(len(neuralEventTargetOnset[i,:])):
            rect = plt.Rectangle((neuralEventTargetOnset[i,k]-widthRast/2,plotTrial-hightRast/2),widthRast,hightRast, color ='gray')#colors[0])# '#e8503e',alpha = 1)#,color = '#ff4f0e')
            ax1.add_patch(rect)
        plotTrial += 1 
        
    for i in smallRelaDrifEventIndex:
        for k in range(len(neuralEventTargetOnset[i,:])):
            rect = plt.Rectangle((neuralEventTargetOnset[i,k]-widthRast/2,plotTrial-hightRast/2),widthRast,hightRast,color ='#1f77b4')#colors[1])# '#1226aa)#color = '#1f77b4')
            ax1.add_patch(rect)
        plotTrial += 1
    for i in largeRelaDrifEventIndex:
        for k in range(len(neuralEventTargetOnset[i,:])):
            rect = plt.Rectangle((neuralEventTargetOnset[i,k]-widthRast/2,plotTrial-hightRast/2),widthRast,hightRast, color ='#d62728')#colors[0])# '#e8503e',alpha = 1)#,color = '#ff4f0e')
            ax1.add_patch(rect)
        plotTrial += 1

    for i in VPHandEventIndex:
        for k in range(len(neuralEventTargetOnset[i,:])):
            rect = plt.Rectangle((neuralEventTargetOnset[i,k]-widthRast/2,plotTrial-hightRast/2),widthRast,hightRast, color ='k')#colors[0])# '#e8503e',alpha = 1)#,color = '#ff4f0e')
            ax1.add_patch(rect)
        plotTrial += 1    
    # Vertical marker at t = -0.5 s spanning all raster rows, labelled
    # 'Disparity onset' at the top edge of the axis (y = plotTrial+1, the
    # upper y-limit set below).
    ax1.plot([-0.5 for jj in range(plotTrial+1)],[jj for jj in range(plotTrial+1)],'-',lw = 1,color = 'k')
    ax1.text(-.5, plotTrial+1, 'Disparity onset', horizontalalignment='center',size=7,color = 'k')
    ax1.set_ylabel('Trial #')
    ax1.set_xticks(np.arange(-1, 1.51, step=0.5))
    ax1.set_xlim([-.9, 1.4])

    ax2.set_xlabel('Time from target onset (s)')
    ax1.set_ylim([0,plotTrial+1])
    # Horizontal bar at 15 Hz spanning -0.8 to -0.6 s.
    # NOTE(review): meaning not labelled here (time window / scale marker?) -- confirm.
    ax2.plot([i/10-.8 for i in [0,1,2]],[15 for i in [0,1,2]],'-',color = 'k',alpha = 1)

    # Lower panel: per-bin mean with a shaded band of
    # nanstd / sqrt(number of rows), i.e. an SEM.
    # NOTE(review): shape[0] counts every row, including rows that are NaN
    # in a bin, so the SEM is understated wherever NaNs occur.
    ax2.plot([i/10-.8 for i in range(22)],np.nanmean(upRelaDrifFRBin,axis = 0),color = '#d62728',lw = 1,label = 'High prior')
    ax2.fill_between([i/10-.8 for i in range(22)],np.nanmean(upRelaDrifFRBin,axis = 0)-np.nanstd(upRelaDrifFRBin,axis = 0)/np.sqrt(upRelaDrifFRBin.shape[0]),
              np.nanmean(upRelaDrifFRBin,axis = 0)+np.nanstd(upRelaDrifFRBin,axis = 0)/np.sqrt(upRelaDrifFRBin.shape[0]),
              edgecolor='#d62728', facecolor='#d62728',alpha=0.5)
    
    ax2.plot([i/10-.8 for i in range(22)],np.nanmean(downRelaDrifFRBin,axis = 0),color = '#1f77b4',lw = 1,label = 'Low prior')
    ax2.fill_between([i/10-.8 for i in range(22)],np.nanmean(downRelaDrifFRBin,axis = 0)-np.nanstd(downRelaDrifFRBin,axis = 0)/np.sqrt(downRelaDrifFRBin.shape[0]),
              np.nanmean(downRelaDrifFRBin,axis = 0)+np.nanstd(downRelaDrifFRBin,axis = 0)/np.sqrt(downRelaDrifFRBin.shape[0]),
              edgecolor='#1f77b4',   facecolor='#1f77b4',alpha=0.5)
    ax2.set_ylabel('Firing rate (Hz)')

    # VP and P condition averages (timeBins holds the same x values as the
    # inline list comprehensions above).
    ax2.plot(timeBins,np.nanmean(VPHandFRTemp,0),label = 'VP',color = 'k')#
    ax2.fill_between([i/10-.8 for i in range(22)],np.nanmean(VPHandFRTemp,axis = 0)-np.nanstd(VPHandFRTemp,axis = 0)/np.sqrt(VPHandFRTemp.shape[0]),
              np.nanmean(VPHandFRTemp,axis = 0)+np.nanstd(VPHandFRTemp,axis = 0)/np.sqrt(VPHandFRTemp.shape[0]),
              edgecolor='k',   facecolor='k',alpha=0.5)
    ax2.plot(timeBins,np.nanmean(PFRTemp,0),label = 'P',color = 'gray')#
    ax2.fill_between([i/10-.8 for i in range(22)],np.nanmean(PFRTemp,axis = 0)-np.nanstd(PFRTemp,axis = 0)/np.sqrt(PFRTemp.shape[0]),
              np.nanmean(PFRTemp,axis = 0)+np.nanstd(PFRTemp,axis = 0)/np.sqrt(PFRTemp.shape[0]),
              edgecolor='gray',   facecolor='gray',alpha=0.5)    
    # ax2.legend(bbox_to_anchor = [.4,.5],ncol=2,handlelength = .75)
    plt.tight_layout()
    # Output name is built, but saving is disabled; the figure is only shown.
    fileName = 'fig5A_PMC'+'ExamplePriorRasterTraj'+neurName+'.pdf'
    # plt.savefig(fileName,dpi = 600)
plt.show()
#%%fig5B
# Figure 5B: decoding accuracy over time (from target onset) for the premotor
# and parietal (Area5) populations. Row 0 of each `meanDownUp*Lag1` array is
# drawn as the line; rows 1 and 2 are the edges of the shaded band
# (presumably a confidence interval -- TODO confirm).
with plt.style.context('style_paper.mplstyle'):
    # Use the style's colour cycle starting from its 4th entry.
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"][3:]
    f, ax1 = plt.subplots(ncols=1, nrows=1, sharex=True,sharey=True,figsize=[7.25/3,3.54/1.5])

    # Short bar at y=0.58 from -0.8 s to -0.4 s in the premotor colour.
    # NOTE(review): presumably a significance marker -- confirm.
    ax1.plot([i*.1-.8 for i in range(0,5)],[.58 for i in range(0,5) ],'-',color=colors[0],linewidth = 1)
    # corrRateScoreBinSmooPMCMean =  np.nanmean(corrRateScoreBinSmooPMCLagHand[lag]['corrRateScoreBinSmooPMC'],axis = 0)
    ax1.plot([i*.1-.8 for i in range(22)],meanDownUpPMCLag1[0,:], color = colors[0],label = 'Premotor',linewidth = 1)    
    ax1.fill_between([i*.1-.8 for i in range(22)],meanDownUpPMCLag1[1,:],meanDownUpPMCLag1[2,:],alpha = .2,color = colors[0])
    
    #Area5
    # corrRateScoreBinSmooArea5Mean = np.nanmean(corrRateScoreBinSmooArea5LagHand[lag]['corrRateScoreBinSmooArea5'],axis = 0)
    ax1.plot([i*.1-.8 for i in range(22)],meanDownUpArea5Lag1[0,:], color = colors[1],label = 'Parietal',linewidth = 1)
    ax1.fill_between([i*.1-.8 for i in range(22)],meanDownUpArea5Lag1[1,:],meanDownUpArea5Lag1[2,:],alpha = .2,color = colors[1])
    # The legend is built here, so the gray box and dashed line added below
    # are not included in it.
    plt.legend(loc = 'best')
    # Gray box from -0.8 s to -0.5 s (y 0.47-0.575), drawn through pyplot on
    # the current axes (ax1).
    plt.fill_between([i*.1-.8 for i in range(0,4)],
                          [.47 for k in range(0,4)],
                          [.575 for k in range(0,4)],
                         edgecolor = [],facecolor = 'gray',alpha = 0.4)
    # Dashed reference line at accuracy 0.5 (chance for two classes -- assumption).
    ax1.plot([i*.1-.8 for i in range(22)],[.5 for i in range(22)], color = 'k',ls= '--')
    ax1.set_xticks(np.arange(-1, 1.51, step=0.5))
    plt.xlim([-.9,1.4])
    plt.ylim(bottom = 0.465)
    ax1.set_xlabel('Time from target onset (s)')
    ax1.set_ylabel("Decoding accuracy")
    plt.tight_layout()
    # Output name is built, but saving is disabled; the figure is only shown.
    fileName = 'fig5B_'+'decoPrevRelaDrifByFR.pdf'
    # plt.savefig(fileName,dpi = 600)
plt.show()
#%%fig5C
# Figure 5C: decoding accuracy versus trial lag (0-4) for premotor (PMC) and
# parietal (Area5) populations, with an exponential fit (expFunction).
# Row 0 of each `*RefeCorrMeanUpLowLag` array is the mean per lag; the error
# bars are symmetric with half-width (row 1 - row 0).
# Only lags 1-3 enter the fit; each fitted curve is drawn over lags 1-4.
fitLags = [1, 2, 3]
paraExpPMC = getExpParas(fitLags, PMCRefeCorrMeanUpLowLag[0,:][fitLags])
xPMC = np.linspace(1,4,1000)
# Module-level alias kept so the name `x` still exists after this cell.
x = xPMC
yPMC = expFunction(xPMC, *paraExpPMC)
paraExpArea5 = getExpParas(fitLags, Area5RefeCorrMeanUpLowLag[0,:][fitLags])
xArea5 = np.linspace(1,4,1000)
# Evaluate the parietal fit on its own grid, which is the grid it is plotted against.
yArea5 = expFunction(xArea5, *paraExpArea5)
markSize=8
with plt.style.context('style_paper.mplstyle'):
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    plt.figure(figsize = [7.25/3,3.54/1.5])
    plt.errorbar([i for i in range(5)], PMCRefeCorrMeanUpLowLag[0,:],yerr =  PMCRefeCorrMeanUpLowLag[1,:] - PMCRefeCorrMeanUpLowLag[0,:],fmt='.',color = colors[3],lw = 1, elinewidth = 1,label = 'Premotor',markersize=markSize )
    plt.plot(xPMC,yPMC,color = colors[3],ls = '-')
    # Dashed reference line at accuracy 0.5 (chance for two classes -- assumption).
    plt.plot(np.linspace(-.1,4.1,10),[0.5 for i in range(10)],color = 'k',ls = '--')
    plt.errorbar([i for i in range(5)], Area5RefeCorrMeanUpLowLag[0,:],yerr = Area5RefeCorrMeanUpLowLag[1,:] - Area5RefeCorrMeanUpLowLag[0,:],fmt='.',color = colors[4],lw = 1, elinewidth = 1,label = 'Parietal',markersize=markSize )
    plt.plot(xArea5,yArea5,color = colors[4],ls = '-')
    plt.xticks(np.arange(0, 5, step=1))
    plt.yticks(np.arange(0.44, .59, step=.02))
    plt.ylim([.46,.565])
    # plt.text(1,0.568,s = '***',horizontalalignment='center')
    plt.xlabel('Trial lag')
    plt.legend()
    plt.ylabel('Decoding accuracy')
    plt.tight_layout()
    # Output name is built, but saving is disabled; the figure is only shown.
    fileName = 'fig5C_decoPrevRelaDrifByFRLag.pdf'
    # plt.savefig(fileName,dpi = 600)
plt.show()
#%%fig5E
# Figure 5E: variance explained per principal component.
# Left: Prior PCs, with the Prior's own variance (black, exp_var[0]) and the
# P_com projection onto them (gray, pom_on_pri).
# Right: P_com PCs, with the Prior projection (black, pri_on_pcom) and P_com's
# own variance (gray, exp_var[1]).
# Print the summed explained variance of each series (first plotPCs components
# for exp_var).
print(sum(pom_on_pri),sum(pri_on_pcom),sum(exp_var[1][:plotPCs]),sum(exp_var[0][:plotPCs]))
with plt.style.context('style_paper.mplstyle'):    
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    f, (ax1,ax2) = plt.subplots(ncols=2, nrows=1, sharex=True,sharey=False,figsize=[2.2*1.5,2.4])
    # Paired bars per PC: the second series is shifted right by one bar width (0.3).
    # NOTE(review): assumes pom_on_pri / pri_on_pcom have plotPCs entries -- confirm.
    ax1.bar(np.arange(1,plotPCs+1), exp_var[0][:plotPCs], width=0.3, color='k')
    ax1.bar(np.arange(1,plotPCs+1)+0.3, pom_on_pri, width=0.3, color='gray')
    ax1.set_xlabel('$Prior$'+' PCs')
    ax1.set_ylabel('Explained Var.(%)')
    # Legend entries follow bar-call order: black first, then gray.
    ax1.legend(['$Prior$','$P_{com}$'])
    ax2.bar(np.arange(1,plotPCs+1), pri_on_pcom, width=0.3, color='k')
    ax2.bar(np.arange(1,plotPCs+1)+0.3, exp_var[1][:plotPCs], width=0.3, color='gray')
    ax2.set_xlabel('$P_{com}$ PCs')
    ax2.legend(['$Prior$','$P_{com}$'])
    plt.tight_layout()
    # Output name is built, but saving is disabled; the figure is only shown.
    fileName = 'fig5E_pri_pcom_proj.pdf'
    # plt.savefig(fileName,dpi = 600)
plt.show()
#%%fig5F
# Figure 5F: premotor decoding accuracy over time for the prior (left y-axis,
# colour colors[3]) and the posterior (right y-axis, black), sharing one time axis.
with plt.style.context('style_paper.mplstyle'):
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    fig = plt.figure(figsize=[7.25/3+.5,3.54/1.5])
    ax1 = fig.add_subplot(111)
    ax1.plot([i*.1-.8 for i in range(22)], corrRateScoreBinSmooPMCMeanPrior, color = colors[3])
    # Bar at y=0.56 from -0.8 s to -0.4 s.
    # NOTE(review): presumably a significance marker -- confirm.
    ax1.plot([i*.1-.8 for i in range(5)], [0.56 for i in range(5)], color = colors[3])
    ax1.set_ylabel('Prior decoding accuracy',color=colors[3])
    ax1.set_xlabel('Time from target onset (s)')    
    ax1.tick_params(axis='y', labelcolor= colors[3])
    ax2 = ax1.twinx()  # second y-axis sharing ax1's x-axis
    ax2.plot([i*.1-.8 for i in range(22)], corrRateScoreBinSmooPMCMeanPosterior, color = 'k')
    # Bar at y=0.9 from -0.1 s to 1.3 s (bins 7-21).
    # NOTE(review): presumably a significance marker -- confirm.
    ax2.plot([i*.1-.8 for i in range(7,22)],[.9 for i in range(7,22) ],'-',color='k',linewidth = 1) #7
    ax2.tick_params(axis='y', labelcolor='k')
    ax2.set_ylabel('Posterior decoding accuracy ',color='k')
    # pyplot calls act on the current axes; the x-axis is shared with the twin.
    plt.xticks(np.arange(-1, 1.51, step=.5))
    plt.xlim([-.9,1.4])
    plt.tight_layout()
    # Output name is built, but saving is disabled; the figure is only shown.
    fileName = 'fig5F_PMC'+'priorPosterior.pdf'
    # plt.savefig(fileName,dpi = 600)
plt.show()
